{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# t5-base-finetuned-wikiSQL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model/tokenizer loader classes and the Hugging Face dataset loader\n",
    "from datasets import Split, load_dataset\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "/data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/auto/modeling_auto.py:1437: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
      "  warnings.warn(\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'<pad> SELECT COUNT Model fine tuned FROM table WHERE Base model = BERT</s>'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "MODEL_NAME = \"mrm8488/t5-base-finetuned-wikiSQL\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "# AutoModelWithLMHead is deprecated; T5 is an encoder-decoder model, so it is\n",
    "# loaded through AutoModelForSeq2SeqLM as the deprecation warning recommends.\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
    "\n",
    "def get_sql(query, max_length=128):\n",
    "  \"\"\"Translate an English question into SQL text with the fine-tuned T5 model.\n",
    "\n",
    "  Args:\n",
    "    query: Natural-language question to translate.\n",
    "    max_length: Upper bound on the generated sequence length, forwarded to\n",
    "      `model.generate`. Without an explicit value the library's short default\n",
    "      applies and longer queries are cut off mid-statement.\n",
    "\n",
    "  Returns:\n",
    "    The decoded SQL string with special tokens (`<pad>`, `</s>`) removed.\n",
    "  \"\"\"\n",
    "  input_text = \"translate English to SQL: %s </s>\" % query\n",
    "  features = tokenizer([input_text], return_tensors='pt')\n",
    "\n",
    "  output = model.generate(input_ids=features['input_ids'],\n",
    "               attention_mask=features['attention_mask'],\n",
    "               max_length=max_length)\n",
    "\n",
    "  # skip_special_tokens keeps the padding/EOS markers out of the returned SQL.\n",
    "  return tokenizer.decode(output[0], skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14605fcc2dd84c398ece71102506a8d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/6.57k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64692f098ea04b6d8858b12c41d92cd2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading metadata:   0%|          | 0.00/2.76k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4e0266619f1430d8cba26df4d471a3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/7.80k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset wikisql/default to /data/rlhf/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae0f4bdc80e64894ae9e5333182b62fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/26.2M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b6af6938de4e4149b0cbbdc9f91c8d74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/15878 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a4dd6fc4c1e42a081b4191764a98831",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/8421 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a282a6db0ead436b8e2be68ac41b51a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/56355 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset wikisql downloaded and prepared to /data/rlhf/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d. Subsequent calls will reuse this data.\n"
     ]
    }
   ],
   "source": [
    "# Load only the validation split of WikiSQL.\n",
    "valid_dataset = load_dataset(\n",
    "    path='wikisql',\n",
    "    split=Split.VALIDATION,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'SELECT Position FROM table WHERE School/Club Team = Butler CC (KS)'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference SQL of the first validation example. A literal index is used so the\n",
    "# cell does not depend on a loop variable left over from another cell.\n",
    "valid_dataset[0]['sql']['human_readable']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "What position does the player who played for butler cc (ks) play?\n",
      "Ideal Answer:\tSELECT Position FROM table WHERE School/Club Team = Butler CC (KS)\n",
      "Generated:\t<pad> SELECT Position FROM table WHERE School/Club Team = Butler CCC (\n",
      "==================================================\n",
      "How many winning drivers were the for the rnd equalling 5?\n",
      "Ideal Answer:\tSELECT COUNT Winning driver FROM table WHERE Rnd = 5\n",
      "Generated:\t<pad> SELECT COUNT Winning driver FROM table WHERE Rnd = 5</s>\n",
      "==================================================\n",
      "List the scores of all games when Miami were listed as the first Semi finalist\n",
      "Ideal Answer:\tSELECT Score FROM table WHERE Semi-Finalist #1 = Miami\n",
      "Generated:\t<pad> SELECT Score FROM table WHERE Semi finalist = Miami</s>\n",
      "==================================================\n",
      " what's the evening gown where state is south dakota\n",
      "Ideal Answer:\tSELECT Evening Gown FROM table WHERE State = South Dakota\n",
      "Generated:\t<pad> SELECT Evening Gown FROM table WHERE State = South Dakota</s>\n",
      "==================================================\n",
      "The common of Chieri has what population density?\n",
      "Ideal Answer:\tSELECT Density (inhabitants/km 2 ) FROM table WHERE Common of = Chieri\n",
      "Generated:\t<pad> SELECT Population density FROM table WHERE Common = chiri</s>\n",
      "==================================================\n",
      "Nagua has the area (km²) of?\n",
      "Ideal Answer:\tSELECT Area (km²) FROM table WHERE Capital = Nagua\n",
      "Generated:\t<pad> SELECT Area (km2) FROM table WHERE City = nagua</s>\n"
     ]
    }
   ],
   "source": [
    "# Spot-check a handful of validation examples against their reference SQL.\n",
    "SAMPLE_INDICES = [0, 50, 100, 200, 500, 1000]\n",
    "\n",
    "for i in SAMPLE_INDICES:\n",
    "    # Fetch the row first; indexing the column first (valid_dataset['sql'][i])\n",
    "    # would materialise the entire column on every iteration.\n",
    "    example = valid_dataset[i]\n",
    "    sql = example['sql']['human_readable']\n",
    "    question = example['question']\n",
    "    generated = get_sql(question)\n",
    "    print('='*50)\n",
    "    print(question)\n",
    "    print('Ideal Answer:\\t' + sql)\n",
    "    print('Generated:\\t' + generated)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "finetune",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
